# -*- coding: utf-8 -*- 
"""
========================================================================================================================
@Author: 孟颖
@email: 652044581@qq.com
@date: 2023/4/20 10:19
@desc: 加密解密模块
========================================================================================================================
"""
import os
from pathlib import Path, PurePath
from Crypto import Random
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from Crypto.PublicKey import RSA
import base64, hashlib
from Crypto.Cipher import PKCS1_v1_5
from gmssl import sm2, sm4
from enum import Enum


class HashCipher:
    """Hex-digest helpers for common hashlib algorithms.

    Every method accepts either a ``str`` (hashed as its UTF-8 encoding) or a
    bytes-like object (hashed as-is) and returns the lowercase hex digest.
    """

    @staticmethod
    def _to_bytes(message):
        """Return ``message`` UTF-8 encoded if it is a ``str``; otherwise return it unchanged.

        Non-str values go straight to hashlib, which accepts bytes-like objects
        and raises ``TypeError`` for anything else.
        """
        return message.encode('utf-8') if isinstance(message, str) else message

    @staticmethod
    def md5(message):
        """
        Returns the MD5 hex digest of the message (str or bytes-like).

        NOTE: MD5 is not collision resistant; use it for checksums or
        compatibility only, never for security decisions.
        """
        return hashlib.md5(HashCipher._to_bytes(message)).hexdigest()

    @staticmethod
    def sha1(message):
        """
        Returns the SHA1 hex digest of the message (str or bytes-like).

        NOTE: SHA-1 is no longer collision resistant; prefer sha256/sha512.
        """
        return hashlib.sha1(HashCipher._to_bytes(message)).hexdigest()

    @staticmethod
    def sha256(message):
        """
        Returns the SHA256 hex digest of the message (str or bytes-like).
        """
        return hashlib.sha256(HashCipher._to_bytes(message)).hexdigest()

    @staticmethod
    def sha512(message):
        """
        Returns the SHA512 hex digest of the message (str or bytes-like).
        """
        return hashlib.sha512(HashCipher._to_bytes(message)).hexdigest()


class Base64Cipher:
    """Convert between raw bytes and standard Base64 text."""

    @staticmethod
    def bytes_to_base64(bytes_data):
        """
        Encode ``bytes_data`` with standard Base64 and return the result as a str.
        """
        encoded = base64.b64encode(bytes_data)
        return str(encoded, 'utf-8')

    @staticmethod
    def base64_to_bytes(base64_string):
        """
        Decode the Base64 text ``base64_string`` and return the raw bytes.
        """
        raw_text = base64_string.encode('utf-8')
        decoded = base64.b64decode(raw_text)
        return decoded


class AESCipherCBC:
    """AES in CBC mode with PKCS#7 padding; ciphertext is exchanged as Base64 text.

    NOTE: the same IV is reused for every message encrypted by one instance, so
    identical plaintexts produce identical ciphertexts.
    """

    def __init__(self, key, iv):
        """
        Store the key and IV used for both encryption and decryption.

        :param key: AES key as ``str`` (UTF-8 encoded) or ``bytes``; AES needs 16, 24 or 32 bytes.
        :param iv: 16-byte initialisation vector as ``str`` (UTF-8 encoded) or ``bytes``.
        """
        self.key = self._to_bytes(key)
        # AES.new requires a bytes-like IV, so a str IV is encoded the same way as the key.
        self.iv = self._to_bytes(iv)

    @staticmethod
    def _to_bytes(value):
        """Return ``value`` UTF-8 encoded if it is a ``str``; otherwise return it unchanged."""
        return value.encode('utf-8') if isinstance(value, str) else value

    def encrypt(self, plaintext):
        """
        Encrypt ``plaintext`` (str) and return the ciphertext as Base64 text.

        The UTF-8 plaintext is PKCS#7-padded to the AES block size before encryption.
        """
        cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
        padded_plaintext = pad(plaintext.encode('utf-8'), AES.block_size)
        ciphertext = cipher.encrypt(padded_plaintext)
        return base64.b64encode(ciphertext).decode('utf-8')

    def decrypt(self, ciphertext):
        """
        Decrypt Base64 ``ciphertext`` produced by :meth:`encrypt` and return the str plaintext.

        The padding is removed with ``unpad``, which raises ``ValueError`` if it is invalid.
        """
        # A fresh cipher object is needed per call because CBC objects are stateful.
        cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
        decoded_ciphertext = base64.b64decode(ciphertext)
        padded_plaintext = cipher.decrypt(decoded_ciphertext)
        plaintext = unpad(padded_plaintext, AES.block_size)
        return plaintext.decode('utf-8')


class AESCipherEBC:
    """AES in ECB mode; ciphertext is exchanged as Base64 text.

    The class name keeps its historical "EBC" spelling; the cipher mode is ECB.
    NOTE: ECB encrypts equal 16-byte blocks to equal ciphertext blocks and so
    leaks plaintext structure; prefer :class:`AESCipherCBC` for new data.
    """

    def __init__(self, key, use_padding=False):
        """
        Store the key used for both encryption and decryption (ECB has no IV).

        :param key: AES key as ``str`` (UTF-8 encoded) or ``bytes``; AES needs 16, 24 or 32 bytes.
        :param use_padding: when True, PKCS#7-pad on encrypt and unpad on decrypt so
            any plaintext length works. Defaults to False for compatibility with
            existing ciphertexts; the UTF-8 plaintext must then be a multiple of 16 bytes.
        """
        self.key = key.encode('utf-8') if isinstance(key, str) else key
        self.use_padding = use_padding

    def encrypt(self, plaintext):
        """
        Encrypt ``plaintext`` (str) and return the ciphertext as Base64 text.

        Without ``use_padding``, pycryptodome raises ``ValueError`` unless the
        UTF-8 plaintext length is a multiple of the AES block size.
        """
        cipher = AES.new(self.key, AES.MODE_ECB)
        data = plaintext.encode('utf-8')
        if self.use_padding:
            data = pad(data, AES.block_size)
        cipherByte = cipher.encrypt(data)
        return base64.b64encode(cipherByte).decode('utf-8')

    def decrypt(self, ciphertext):
        """
        Decrypt Base64 ``ciphertext`` produced by :meth:`encrypt` and return the str plaintext.

        Padding is stripped only when ``use_padding`` is enabled.
        """
        cipher = AES.new(self.key, AES.MODE_ECB)
        decoded_ciphertext = base64.b64decode(ciphertext)
        data = cipher.decrypt(decoded_ciphertext)
        if self.use_padding:
            data = unpad(data, AES.block_size)
        return data.decode('utf-8')


class RSACipher:
    """RSA helpers: PEM key-pair generation and PKCS#1 v1.5 encryption with Base64 text."""

    @staticmethod
    def generate(path, bits=1024):
        """
        Generate an RSA key pair and save it as ``private.pem`` and ``public.pem`` in ``path``.

        Nothing is written if either file already exists, so an existing
        (possibly partial) pair is never overwritten.

        :param path: target directory as ``str`` or path-like object.
        :param bits: modulus size. 1024 is kept as the default for compatibility,
            but 1024-bit RSA is considered weak; pass 2048 or more for new keys.
        """
        # Path(...) accepts both str and Path.
        directory = Path(path)
        private_path = directory / "private.pem"
        public_path = directory / "public.pem"
        if private_path.exists() or public_path.exists():
            return
        rsa = RSA.generate(bits, Random.new().read)
        # Save the private key.
        private_path.write_bytes(rsa.export_key())
        # Derive and save the matching public key.
        public_path.write_bytes(rsa.publickey().export_key())

    @staticmethod
    def encrypt(message: str, public_path=None, public_key=None):
        """
        Encrypt ``message`` with RSA PKCS#1 v1.5 and return the ciphertext as Base64 text.

        :param message: plaintext; its UTF-8 encoding must fit in one RSA block
            (at most key-size-in-bytes minus 11 bytes for PKCS#1 v1.5).
        :param public_path: PEM file to load the public key from; takes precedence over ``public_key``.
        :param public_key: an already imported RSA key object.
        :raises ValueError: if neither a key path nor a key object is supplied.
        """
        if public_path:
            with open(public_path, 'rb') as f:
                public_key = RSA.import_key(f.read())
        if public_key is None:
            raise ValueError("public_path or public_key is required")
        cipher = PKCS1_v1_5.new(public_key)
        cipher_text = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
        return cipher_text.decode('utf-8')

    @staticmethod
    def decrypt(message: str, private_path=None, private_key=None):
        """
        Decrypt Base64 ciphertext produced by :meth:`encrypt` and return the str plaintext.

        :param message: Base64 text of the ciphertext.
        :param private_path: PEM file to load the private key from; takes precedence over ``private_key``.
        :param private_key: an already imported RSA key object.
        :raises ValueError: if no key is supplied, or if decryption fails ("解密失败").

        NOTE(review): reporting PKCS#1 v1.5 decryption failures to an untrusted
        party enables padding-oracle attacks; prefer OAEP for new protocols.
        """
        if private_path:
            with open(private_path, 'rb') as f:
                private_key = RSA.import_key(f.read())
        if private_key is None:
            raise ValueError("private_path or private_key is required")
        cipher = PKCS1_v1_5.new(private_key)
        # PKCS1_v1_5.decrypt returns the sentinel on failure. Compare its identity
        # instead of calling .decode() on a str sentinel.
        failed = object()
        text = cipher.decrypt(base64.b64decode(message.encode('utf-8')), failed)
        if text is failed:
            raise ValueError("解密失败")
        return text.decode('utf-8')


class SM2Cipher:
    """SM2 public-key encryption via gmssl; ciphertext is exchanged as hex text."""

    def __init__(self, public_key=None, private_key=None):
        """
        Create the gmssl SM2 client.

        :param public_key: SM2 public key as a hex string; used by :meth:`encrypt`.
        :param private_key: SM2 private key as a hex string; used by :meth:`decrypt`.
        """
        self.public_key = public_key
        self.private_key = private_key
        self.EncryptClient = sm2.CryptSM2(public_key=self.public_key, private_key=self.private_key)

    def encrypt(self, message: str) -> str:
        """SM2-encrypt the UTF-8 encoding of ``message`` and return the ciphertext as a hex string."""
        return self.EncryptClient.encrypt(message.encode('utf-8')).hex()

    def decrypt(self, message: str) -> str:
        """
        Decrypt the hex ciphertext ``message`` produced by :meth:`encrypt` and return the plaintext.

        :raises ValueError: if the gmssl client returns no plaintext.
        """
        plain = self.EncryptClient.decrypt(bytes.fromhex(message))
        # NOTE(review): gmssl's CryptSM2.decrypt can return None instead of raising.
        # Surface that as a clear error rather than an AttributeError on .decode().
        if plain is None:
            raise ValueError("SM2 decryption failed")
        return plain.decode('utf-8')


class Sm4Cipher:
    """SM4 symmetric encryption (ECB or CBC) via gmssl; ciphertext is exchanged as hex text."""

    class CipherModeEnum(Enum):
        ECB = "ecb"
        CBC = "cbc"

    def __init__(self, key, iv=None, cipher_mode: str = CipherModeEnum.CBC.value):
        """
        :param key: 16-byte SM4 key as ``str`` (UTF-8 encoded) or ``bytes``; needed in every mode.
        :param iv: 16-byte IV as ``str`` (UTF-8 encoded) or ``bytes``; required for CBC, unused by ECB.
        :param cipher_mode: "ecb" or "cbc"; a ``CipherModeEnum`` member is also accepted.
        :raises ValueError: for an unsupported mode, or CBC mode without an IV.
        """
        try:
            mode = Sm4Cipher.CipherModeEnum(cipher_mode)
        except ValueError:
            raise ValueError('cipher mode just support cbc or ecb') from None
        # Store the plain string value so the 'ecb'/'cbc' comparisons below keep working.
        self.cipher_mode = mode.value
        self.key = self._to_bytes(key)
        self.iv = None if iv is None else self._to_bytes(iv)
        if self.cipher_mode == 'cbc' and self.iv is None:
            raise ValueError('cbc mode requires an iv')

    @staticmethod
    def _to_bytes(value):
        """Return ``value`` UTF-8 encoded if it is a ``str``; otherwise return it unchanged."""
        return value.encode('utf-8') if isinstance(value, str) else value

    def encrypt(self, plaintext: str) -> str:
        """
        :param plaintext: data to encrypt (its UTF-8 encoding is used)
        :return: ciphertext bytes as a hex string
        """
        enc_data = plaintext.encode()
        crypt_sm4 = sm4.CryptSM4()
        crypt_sm4.set_key(self.key, sm4.SM4_ENCRYPT)
        if self.cipher_mode == 'ecb':
            encrypt_value = crypt_sm4.crypt_ecb(enc_data)
        else:
            encrypt_value = crypt_sm4.crypt_cbc(self.iv, enc_data)
        return encrypt_value.hex()

    def decrypt(self, encrypted_text: str) -> str:
        """
        :param encrypted_text: hex ciphertext produced by :meth:`encrypt`
        :return: the decrypted plaintext as a str
        """
        crypt_sm4 = sm4.CryptSM4()
        crypt_sm4.set_key(self.key, sm4.SM4_DECRYPT)
        if self.cipher_mode == 'ecb':
            decrypt_value = crypt_sm4.crypt_ecb(bytes.fromhex(encrypted_text))
        else:
            decrypt_value = crypt_sm4.crypt_cbc(self.iv, bytes.fromhex(encrypted_text))
        return decrypt_value.decode(encoding='utf-8')


if __name__ == '__main__':
    # # hashlib hashing
    # print(HashCipher.md5("asqwdasd"))
    #
    # # AES encryption
    # key = "c-QULHn+u=-BUSQ$"
    # message1 = 'abcdefghijklmnhi'
    # cipher = AESCipherEBC(key)
    # encrypted_message = cipher.encrypt(message1)
    # print('Encrypted message:', encrypted_message)
    # decrypted_message = cipher.decrypt(encrypted_message)
    # print('Decrypted message:', decrypted_message)

    # SM2 round-trip demo with a fixed, publicly visible key pair (never reuse these keys).
    # The private key must be a plain str: a trailing comma would turn it into a 1-tuple
    # and break decryption.
    private_key = "6246191249ad96b78a697cc90b57d84a97a27845bb49c43d2314dd03ff3a4ca3"
    public_key = "049bacaa0c4b7d47a80aef9288121537b0ed835ae10e45fd5025224ad77ee7e4c7da23ef810d3ee206a7165959e6b9e65aaee6df617dd7980f111490c39605302b"
    smUs = SM2Cipher(public_key=public_key, private_key=private_key)

    s = smUs.encrypt("Hello, World!")
    print(s)
    print(smUs.decrypt(s))